import time
import numpy as np
import torch
import cv2
from models.experimental import attempt_load
from utils.datasets import letterbox
from utils.plots import plot_one_box
from utils.general import check_img_size, non_max_suppression, scale_coords
from utils.torch_utils import select_device, time_synchronized


class PlateParser(object):
    """YOLOv5-based license-plate detector.

    Loads a trained YOLOv5 checkpoint and exposes `plate_detect`, which runs
    inference on a single BGR image (numpy array, as read by cv2) and returns
    a list of detection dicts with pixel locations, class names and scores.
    """

    def __init__(self, model_path_name, device_cpu_gpu_type):
        """Load the model and prepare the inference device.

        Args:
            model_path_name: path to the YOLOv5 weights file (.pt).
            device_cpu_gpu_type: device string for select_device,
                e.g. '' / 'cpu' / '0' / '0,1'.
        """
        # Resolve the torch device from the user-supplied string.
        self.device = select_device(device_cpu_gpu_type)
        # Half precision (FP16) is only supported on CUDA devices.
        self.half = self.device.type != 'cpu'
        # Load the FP32 model onto the selected device.
        self.model = attempt_load(model_path_name, map_location=self.device)
        # Convert weights to FP16 when running on GPU.
        if self.half:
            self.model.half()
        self.model_default_image_size = 640
        # Warm-up: run one dummy forward pass so later calls are not hit by
        # CUDA context / kernel compilation latency (GPU only).
        img = torch.zeros((1, 3, self.model_default_image_size, self.model_default_image_size),
                          device=self.device)  # init img
        _ = self.model(img.half() if self.half else img) if self.device.type != 'cpu' else None  # run once

        # Class names; DataParallel-wrapped models keep them under .module.
        self.model_classes_names = self.model.module.names if hasattr(self.model, 'module') else self.model.names
        print('names=', self.model_classes_names)

    def plate_detect(self, img_np, conf_threshold=0.25):
        """Detect plates in a single BGR image.

        Args:
            img_np: HxWx3 uint8 BGR image (cv2 format).
            conf_threshold: confidence threshold passed to NMS.

        Returns:
            list of dicts, one per detection:
            {'location': {'x1', 'y1', 'x2', 'y2'},  # pixel corners in img_np
             'plates': {'score': float, 'plates': class name}}

        Side effect: draws the boxes on img_np in place and writes the
        annotated image to ./tmp.jpg when anything was detected.
        """
        start = time.time()
        stride = int(self.model.stride.max())  # model stride, typically 32
        # Make sure the inference resolution is a multiple of the stride
        # (check_img_size adjusts it and returns the valid size).
        img_size_adjust = check_img_size(self.model_default_image_size, s=stride)  # check img_size
        # Resize and pad image while meeting stride-multiple constraints;
        # letterbox returns (image, ratio, padding) - we only need the image.
        img = letterbox(img_np, new_shape=img_size_adjust)[0]
        # Convert BGR HWC -> RGB CHW.
        img = img[:, :, ::-1].transpose(2, 0, 1)
        img = np.ascontiguousarray(img)
        img = torch.from_numpy(img).to(self.device)
        # Match the model's precision: uint8 -> fp16/fp32.
        img = img.half() if self.half else img.float()
        img /= 255.0  # 0 - 255 to 0.0 - 1.0
        # Add a batch axis if the tensor is unbatched.
        if img.ndimension() == 3:
            img = img.unsqueeze(0)
        # Inference
        t1 = time_synchronized()
        pred = self.model(img)[0]
        # Apply NMS
        pred = non_max_suppression(pred, conf_thres=conf_threshold)
        t2 = time_synchronized()
        # Process detections
        resultList = []
        for i, det in enumerate(pred):  # detections per image
            # len(det) is the number of boxes found in this image
            if len(det):
                # Rescale boxes from the letterboxed size back to img_np size.
                det[:, :4] = scale_coords(img.shape[2:], det[:, :4], img_np.shape).round()
                # Write results
                for *xyxy, conf, cls in reversed(det):
                    # xyxy is ordered (x1, y1, x2, y2) — top-left then
                    # bottom-right corner. BUG FIX: the previous code mapped
                    # 'x2' to xyxy[1] (which is y1) and 'y1' to xyxy[2]
                    # (which is x2), corrupting every returned location.
                    x1, y1, x2, y2 = (int(v) for v in xyxy)
                    top1 = {'location': {'x1': x1,
                                         'x2': x2,
                                         'y1': y1,
                                         'y2': y2
                                         },
                            'plates': {
                                'score': round(float(conf), 4),
                                'plates': self.model_classes_names[int(cls)]
                            }
                            }
                    resultList.append(top1)
                    c1, c2 = (x1, y1), (x2, y2)
                    # Typo fix: 'socre=' -> 'score='.
                    print(c1, c2, 'name=', self.model_classes_names[int(cls)], 'score=', round(float(conf), 2))
                    label = f'{self.model_classes_names[int(cls)]} {conf:.2f}'
                    # Draws on img_np in place.
                    plot_one_box(xyxy, img_np, label=label)
                # NOTE(review): debug artifact — writes the annotated frame to
                # a fixed path on every call with detections.
                cv2.imwrite('./tmp.jpg', img_np)

        print('model detect interval=', round(time.time() - start, 3))
        return resultList
